import numpy as np
import sys
from simNeuralEQ import *
import argparse
import pickle
'''****************************************************
Functions for saving and loading a list with pickle.
They are used to store and reload fwdBwd output data.
****************************************************'''
def saveList(fileName, l):
	"""Pickle ``l`` into ``fileName``, overwriting any existing file."""
	with open(fileName, "wb") as outFile:
		pickle.Pickler(outFile).dump(l)

def loadList(fileName):
	"""Return the object unpickled from ``fileName``.

	Only load files from a trusted source: unpickling can execute arbitrary code.
	"""
	with open(fileName, "rb") as inFile:
		return pickle.Unpickler(inFile).load()


'''*******************************************************
Post-process the training targets according to the parameters (lossFn, mod, simpleDataTraining, ...).
*******************************************************'''
def _symbolClassIndex(chInTrain, mod):
	"""Map transmitted symbol values to integer class indices.

	nrz : value == 1 -> 0, anything else -> 1.
	pam4: value == 1 -> 0, other value >= 0 -> 1, -1 < value < 0 -> 2,
	      otherwise -> 3.
	Any other modulation (e.g. 'pam8') has no mapping and exits, because
	returning all-zero labels would silently train on meaningless targets.
	"""
	levels = np.asarray(chInTrain)
	if mod == 'nrz':
		return np.where(levels == 1, 0, 1)
	if mod == 'pam4':
		# np.select takes the first matching condition, like an if/elif chain.
		return np.select([levels == 1, levels >= 0, levels > -1], [0, 1, 2], default=3)
	sys.exit(f'simpleDataTraining is not supported for {mod}')


def _symbolTargets(chInTrain, mod, modNum, nrzNnOutOne):
	"""Build float target rows from symbol values (manualCrossEntropy / mse).

	Returns an (N, 1) array of 1.0 / 0.0 for a single-output NRZ network,
	otherwise the (N, modNum) one-hot encoding of the symbol class index.
	"""
	classIdx = _symbolClassIndex(chInTrain, mod)
	if nrzNnOutOne and mod == 'nrz':
		return (classIdx == 0).astype(float).reshape(-1, 1)
	return np.eye(modNum)[classIdx]


def trainSetPostProcess (lossFn, nrzNnOutOne, simpleDataTraining, fwdBwdProbTrain, chInTrain, mod):
	"""Convert training data into the target format expected by lossFn.

	Args:
		lossFn: 'crossEntropy', 'manualCrossEntropy' or 'mse'. Any other value
			returns fwdBwdProbTrain unchanged (apart from the list() conversion).
		nrzNnOutOne: truthy when the NRZ network has a single output. Not
			allowed with 'crossEntropy' (prints a message and exits).
		simpleDataTraining: truthy -> targets are derived from the symbol values
			in chInTrain; falsy -> targets come from fwdBwdProbTrain.
		fwdBwdProbTrain: per-sample values indexed as a 2-D array (argmax over
			axis 1, or column 0). Only read when simpleDataTraining is falsy.
		chInTrain: transmitted symbol values. Only read when simpleDataTraining
			is truthy.
		mod: 'nrz', 'pam4' or 'pam8'; anything else exits.

	Returns:
		list of targets: integer class indices for 'crossEntropy'; one-hot rows
		(simpleDataTraining) or the given values for 'manualCrossEntropy' and
		'mse', where 'mse' targets are additionally shifted by -0.5.

	Exits (sys.exit) on an invalid modulation, on nrzNnOutOne with
	'crossEntropy', and on simpleDataTraining with 'pam8' (no symbol mapping).
	"""
	modNumByMod = {'nrz': 2, 'pam4': 4, 'pam8': 8}
	if mod not in modNumByMod:
		sys.exit('invalid modulation')
	modNum = modNumByMod[mod]

	if lossFn == 'crossEntropy':
		if nrzNnOutOne:
			print("nrzNnOutOne & crossEntropy do not work")
			sys.exit()
		if simpleDataTraining:
			fwdBwdProbTrain = _symbolClassIndex(chInTrain, mod)
		else:
			fwdBwdProbTrain = np.argmax(fwdBwdProbTrain, axis=1)

	elif lossFn in ('manualCrossEntropy', 'mse'):
		if simpleDataTraining:
			fwdBwdProbTrain = _symbolTargets(chInTrain, mod, modNum, nrzNnOutOne)
		elif nrzNnOutOne and mod == 'nrz':
			# Single-output network: keep only the first column.
			fwdBwdProbTrain = fwdBwdProbTrain[:, 0]
			if lossFn == 'mse':
				fwdBwdProbTrain = fwdBwdProbTrain.reshape(-1, 1)
			# NOTE(review): manualCrossEntropy keeps this column 1-D, while the
			# simpleDataTraining path yields (N, 1) rows - confirm which shape
			# the consumer expects.
		if lossFn == 'mse':
			# Shift so 0/1 targets become -0.5/+0.5 (zero-mean for balanced data).
			fwdBwdProbTrain = np.array(fwdBwdProbTrain) - 0.5

	return list(fwdBwdProbTrain)

'''***********************************************
Train & Eval
***********************************************'''
def trainEval(
	nEQ,		#@@ neuralEQ model 
	tx,			#@@ TX Class
	chInValid,	#@@ Valid data set (label)
	chOutValid,	#@@	Valid data set (input)
	numEpoch,	#@@ EPOCH
	evalFreq,	#@@ Evaluation frequency
	mod,
	chSBR,
	inSize,
	outSize,
	batchSize,
	delay,
	lossFn,
	opt,
	dataSizeTrain,
	snrTrain,
	flagN,
	chInTrain = None,
	chOutTrain = None,
	trainSnrVariation = False,
	snrVariationStartEpoch = 200,
	):
	"""Train nEQ for numEpoch epochs and periodically evaluate it.

	When chInTrain is None, a fresh training set of dataSizeTrain symbols is
	generated every epoch with tx.run() and Channel(sbr=chSBR, snr=snrTrain);
	otherwise the fixed (chInTrain, chOutTrain) pair is reused every epoch.
	Training targets are built with trainSetPostProcess (simpleDataTraining=1).

	With trainSnrVariation, epochs k <= snrVariationStartEpoch (default 200)
	call Channel.run with flagN=0 and later epochs use flagN.

	Validation runs on (chOutValid, chInValid) at every epoch k where
	k % evalFreq == evalFreq - 1.

	Returns:
		(trainLossHis, validLossHis, berValidHis): the training loss of every
		epoch, and the validation loss / BER of every evaluation.
	"""
	simNEQ = simNeuralEQ(
						txDataTrain=chInTrain, 
						rxDataTrain=chOutTrain, 
						txDataTest=chInValid, 
						rxDataTest=chOutValid, 
						neuralEQ=nEQ, 
						mod=mod
						)
	if chInTrain is not None:
		# Fixed training set: its targets never change, so build them once.
		fixedTrainTargets = trainSetPostProcess(lossFn, False, 1, chInTrain, chInTrain, mod)

	trainLossHis = []
	validLossHis = []
	berValidHis = []
	for k in range(numEpoch):
		if chInTrain is None:	#@@ On the fly generation
			epochFlagN = 0 if (trainSnrVariation and k <= snrVariationStartEpoch) else flagN
			chInTrainOnTheFly = tx.run(dataSizeTrain)
			chOnTheFly = Channel(sbr=chSBR, snr=snrTrain)
			chOutTrainOnTheFly = chOnTheFly.run(chIn=chInTrainOnTheFly, flagN=epochFlagN)
			chInTrainOnTheFly = trainSetPostProcess(
													lossFn,
													False,
													1,
													chInTrainOnTheFly,
													chInTrainOnTheFly,
													mod
													)
		else:
			chInTrainOnTheFly = fixedTrainTargets
			chOutTrainOnTheFly = chOutTrain

		loss = simNEQ.trainNeuralEQ(
									lossFn, 
									opt, 
									batchSize=batchSize, 
									inSize=inSize, 
									outSize=outSize, 
									delay=delay, 
									rxDataTrainNew=chOutTrainOnTheFly, 
									txDataTrainNew=chInTrainOnTheFly
									)
		trainLossHis.append(loss)
		print(f"trainloss: {loss:e},   epoch:{k}/{numEpoch-1}", flush=True)
		if (k % evalFreq == evalFreq-1):
			validLoss, berValid = simNEQ.evalNeuralEQ(
													lossFn, 
													batchSize=batchSize, 
													inSize=inSize, 
													outSize=outSize, 
													delay=delay, 
													rxDataTestNew=chOutValid, 
													txDataTestNew=chInValid
													)
			print(f"validloss: {validLoss:e}, validber: {berValid:e}, epoch:{k}/{numEpoch-1}", flush=True)
			validLossHis.append(validLoss)
			berValidHis.append(berValid)
	return trainLossHis, validLossHis, berValidHis

'''***********************************************
Quantization 
***********************************************'''

#def quantIn(x, 


'''****************************************************
Command-line argument parser definition.
name: used for naming result directory
config: input config file path
****************************************************'''
def parsing_def(argv=None):
	"""Parse the command-line arguments of this script.

	Args:
		argv: list of argument strings to parse; None (default) parses
			sys.argv[1:], as argparse does.

	Returns:
		argparse.Namespace with ``name`` (used to name the result directory,
		required) and ``config`` (input config file path, '' when omitted).
	"""
	parser = argparse.ArgumentParser(description="Neural equalizer (neuralEQ) training / evaluation")
	parser.add_argument('-n', '--name', type=str, default='temp', required=True,
						help='name used for the result directory')
	parser.add_argument('-c', '--config', type=str, default='', required=False,
						help='input config file path')
	return parser.parse_args(argv)



'''****************************************************
Reset the parameters of the torch.nn.Linear layers in the given nEq model.
****************************************************'''
def reset_parameters(nEq, row, col):
	"""Re-initialise every torch.nn.Linear layer of nEq.

	Only the first ``row`` x ``col`` entries of nEq.nnUnit are visited, plus
	every entry of nEq.nnFinalUnit; non-Linear layers are left untouched.
	"""
	def _resetLinearLayers(layers):
		# Re-run each Linear layer's own initialiser; skip other module types.
		for layer in layers:
			if isinstance(layer, torch.nn.Linear):
				layer.reset_parameters()

	for r in range(row):
		for c in range(col):
			_resetLinearLayers(nEq.nnUnit[r][c])
	_resetLinearLayers(nEq.nnFinalUnit)


	
